"""Helper functions for the Better Thermostat component."""

import logging

from homeassistant.components.climate.const import HVACAction, HVACMode

from custom_components.better_thermostat.utils.const import (
    CalibrationMode,
    CONF_PROTECT_OVERHEATING,
)

from custom_components.better_thermostat.utils.helpers import (
    convert_to_float,
    round_by_step,
    heating_power_valve_position,
)

from custom_components.better_thermostat.model_fixes.model_quirks import (
    fix_local_calibration,
    fix_target_temperature_calibration,
)

from custom_components.better_thermostat.utils.mpc import (
    MpcInput,
    MpcParams,
    build_mpc_key,
    compute_mpc,
)

_LOGGER = logging.getLogger(__name__)


def _supports_direct_valve_control(self, entity_id: str) -> bool:
    """Return True if the TRV supports writing a valve percentage."""

    trv_data = self.real_trvs.get(entity_id) or {}
    if trv_data.get("valve_position_entity"):
        return True
    quirks = trv_data.get("model_quirks")
    return bool(getattr(quirks, "override_set_valve", None))


def _build_mpc_params(self, entity_id: str) -> MpcParams:
    """Build MPC parameters based on advanced TRV settings.

    Starts from the ``MpcParams()`` defaults and overrides each supported
    attribute for which the TRV's ``advanced`` settings hold a usable value.
    A value is skipped, so the default stays in effect, in any of these cases:

    - it is missing or ``None``;
    - it cannot be coerced to the expected type;
    - it is NaN;
    - it is an unrecognised boolean string.

    Parameters
    ----------
    self :
            self instance of better_thermostat
    entity_id :
            entity id of the TRV whose ``advanced`` settings are read

    Returns
    -------
    MpcParams
            parameter object with the configured overrides applied
    """
    import math

    adv = (self.real_trvs.get(entity_id, {}) or {}).get("advanced", {}) or {}
    params = MpcParams()
    overrides = {
        "mpc_thermal_gain": float,
        "mpc_loss_coeff": float,
        "mpc_control_penalty": float,
        "mpc_change_penalty": float,
        "mpc_adapt": bool,
        "mpc_gain_min": float,
        "mpc_gain_max": float,
        "mpc_loss_min": float,
        "mpc_loss_max": float,
        "mpc_adapt_alpha": float,
        "percent_hysteresis_pts": float,
        "min_update_interval_s": float,
    }

    for key, caster in overrides.items():
        if key not in adv:
            continue
        value = adv[key]
        # None means "use the default"; also ignore names MpcParams lacks.
        if value is None or not hasattr(params, key):
            continue
        if caster is bool:
            if isinstance(value, str):
                # bool("false") is True, so parse textual booleans explicitly
                # instead of relying on string truthiness.
                text = value.strip().lower()
                if text in ("1", "true", "yes", "on"):
                    coerced = True
                elif text in ("", "0", "false", "no", "off"):
                    coerced = False
                else:
                    continue
            else:
                coerced = bool(value)
        else:
            try:
                coerced = int(float(value)) if caster is int else caster(value)
            except (TypeError, ValueError, OverflowError):
                # OverflowError: e.g. int(float("inf")) or float(10**400).
                continue
            # NaN compares False against everything, which would make any
            # comparison on this parameter meaningless; keep the default.
            if isinstance(coerced, float) and math.isnan(coerced):
                continue
        setattr(params, key, coerced)

    return params


def _compute_mpc_balance(self, entity_id: str):
    """Run the MPC balance algorithm for calibration purposes."""

    trv_state = self.real_trvs.get(entity_id)
    if trv_state is None:
        return None, False

    if self.bt_target_temp is None or self.cur_temp is None:
        trv_state["calibration_balance"] = None
        return None, False

    if getattr(self, "window_open", False) is True:
        trv_state["calibration_balance"] = None
        return None, False

    hvac_mode = getattr(self, "bt_hvac_mode", None)
    if hvac_mode == HVACMode.OFF:
        trv_state["calibration_balance"] = None
        return None, False

    params = _build_mpc_params(self, entity_id)

    try:
        mpc_output = compute_mpc(
            MpcInput(
                key=build_mpc_key(self, entity_id),
                target_temp_C=self.bt_target_temp,
                current_temp_C=self.cur_temp,
                trv_temp_C=trv_state.get("current_temperature"),
                tolerance_K=float(getattr(self, "tolerance", 0.0) or 0.0),
                temp_slope_K_per_min=getattr(self, "temp_slope", None),
                window_open=getattr(self, "window_open", False),
                heating_allowed=True,
                bt_name=getattr(self, "device_name", None),
                entity_id=entity_id,
            ),
            params,
        )
    except (ValueError, TypeError, ZeroDivisionError) as err:
        _LOGGER.debug(
            "better_thermostat %s: MPC calibration compute failed for %s: %s",
            getattr(self, "device_name", "unknown"),
            entity_id,
            err,
        )
        trv_state["calibration_balance"] = None
        return None, False

    if mpc_output is None:
        trv_state["calibration_balance"] = None
        return None, False

    supports_valve = _supports_direct_valve_control(self, entity_id)
    trv_state["calibration_balance"] = {
        "valve_percent": mpc_output.valve_percent,
        "flow_cap_K": mpc_output.flow_cap_K,
        "setpoint_eff_C": mpc_output.setpoint_eff_C,
        "apply_valve": supports_valve,
        "debug": getattr(mpc_output, "debug", None),
    }

    _schedule_mpc = getattr(self, "_schedule_save_mpc_states", None)
    if callable(_schedule_mpc):
        _schedule_mpc()

    return mpc_output, supports_valve


def calculate_calibration_local(self, entity_id) -> float | None:
    """Calculate local delta to adjust the setpoint of the TRV based on the air temperature of the external sensor.

    This calibration is for devices with local calibration option, it syncs the current temperature of the TRV to the target temperature of
    the external sensor.

    Parameters
    ----------
    self :
            self instance of better_thermostat

    Returns
    -------
    float
            new local calibration delta
    """
    _context = "_calculate_calibration_local()"

    def _convert_to_float(value):
        return convert_to_float(value, self.name, _context)

    if self.cur_temp is None or self.bt_target_temp is None:
        return None

    # Add tolerance check
    _cur_external_temp = self.cur_temp
    _cur_target_temp = self.bt_target_temp
    _within_tolerance = _cur_external_temp >= (
        _cur_target_temp - self.tolerance
    ) and _cur_external_temp <= (_cur_target_temp + self.tolerance)

    _calibration_mode = self.real_trvs[entity_id]["advanced"].get(
        "calibration_mode", CalibrationMode.DEFAULT
    )

    if _within_tolerance:
        # When within tolerance, don't adjust calibration but keep MPC valve data fresh
        if _calibration_mode == CalibrationMode.MPC_CALIBRATION:
            _compute_mpc_balance(self, entity_id)
        else:
            self.real_trvs[entity_id].pop("calibration_balance", None)
        return self.real_trvs[entity_id]["last_calibration"]

    _cur_trv_temp_s = self.real_trvs[entity_id]["current_temperature"]
    _calibration_step = self.real_trvs[entity_id]["local_calibration_step"]
    _calibration_step = _convert_to_float(_calibration_step)
    _cur_trv_temp_f = _convert_to_float(_cur_trv_temp_s)
    _current_trv_calibration = _convert_to_float(
        self.real_trvs[entity_id]["last_calibration"]
    )

    if (
        _current_trv_calibration is None
        or _cur_external_temp is None
        or _cur_trv_temp_f is None
        or _calibration_step is None
    ):
        _LOGGER.warning(
            "better thermostat %s: %s Could not calculate local calibration in %s: "
            "trv_calibration: %s, trv_temp: %s, external_temp: %s calibration_step: %s",
            self.device_name,
            entity_id,
            _context,
            _current_trv_calibration,
            _cur_trv_temp_f,
            _cur_external_temp,
            _calibration_step,
        )
        return None

    _cur_external_temp = float(_cur_external_temp)
    _cur_target_temp = float(_cur_target_temp)
    _cur_trv_temp_f = float(_cur_trv_temp_f)
    _current_trv_calibration = float(_current_trv_calibration)
    _calibration_step = float(_calibration_step)

    _new_trv_calibration = (
        _cur_external_temp - _cur_trv_temp_f
    ) + _current_trv_calibration

    _mpc_result = None
    _mpc_use_valve = False
    if _calibration_mode == CalibrationMode.MPC_CALIBRATION:
        _mpc_result, _mpc_use_valve = _compute_mpc_balance(self, entity_id)
        if _mpc_use_valve:
            _new_trv_calibration = _current_trv_calibration
        elif _mpc_result is not None:
            _desired_trv_setpoint: float | None = None
            _mpc_setpoint = getattr(_mpc_result, "setpoint_eff_C", None)
            if isinstance(_mpc_setpoint, (int, float)):
                _desired_trv_setpoint = float(_mpc_setpoint)
            else:
                _mpc_percent = getattr(_mpc_result, "valve_percent", None)
                if isinstance(_mpc_percent, (int, float)):
                    _max_temp = _convert_to_float(self.real_trvs[entity_id]["max_temp"])
                    if _max_temp is not None:
                        _valve_fraction = max(
                            0.0, min(1.0, float(_mpc_percent) / 100.0)
                        )
                        _desired_trv_setpoint = _cur_trv_temp_f + (
                            (float(_max_temp) - _cur_trv_temp_f) * _valve_fraction
                        )

            if _desired_trv_setpoint is not None:
                _new_trv_calibration = _current_trv_calibration - (
                    _desired_trv_setpoint - _cur_target_temp
                )
            else:
                try:
                    _flow_cap = float(_mpc_result.flow_cap_K)
                except (TypeError, ValueError):
                    _flow_cap = None
                if _flow_cap is not None and _flow_cap > 0.0:
                    _new_trv_calibration += _flow_cap
    else:
        self.real_trvs[entity_id].pop("calibration_balance", None)

    if _new_trv_calibration is None:
        return None

    _skip_post_adjustments = _calibration_mode == CalibrationMode.MPC_CALIBRATION

    _new_trv_calibration = float(_new_trv_calibration)

    if _calibration_mode == CalibrationMode.AGGRESIVE_CALIBRATION:
        if self.attr_hvac_action == HVACAction.HEATING:
            if _new_trv_calibration > -2.5:
                _new_trv_calibration -= 2.5

    if _calibration_mode == CalibrationMode.HEATING_POWER_CALIBRATION:
        if self.attr_hvac_action == HVACAction.HEATING:
            _valve_position = heating_power_valve_position(self, entity_id)
            _new_trv_calibration = _current_trv_calibration - (
                (self.real_trvs[entity_id]["local_calibration_min"] + _cur_trv_temp_f)
                * _valve_position
            )

    # Respecting tolerance in all calibration modes, delaying heat
    if not _skip_post_adjustments:
        if self.attr_hvac_action == HVACAction.IDLE:
            if _new_trv_calibration < 0.0:
                _new_trv_calibration += self.tolerance * 2.0

    _new_trv_calibration = fix_local_calibration(self, entity_id, _new_trv_calibration)

    _overheating_protection = self.real_trvs[entity_id]["advanced"].get(
        CONF_PROTECT_OVERHEATING, False
    )

    # Additional adjustment if overheating protection is enabled
    if not _skip_post_adjustments and _overheating_protection is True:
        if self.attr_hvac_action == HVACAction.IDLE:
            _new_trv_calibration += (
                _cur_external_temp - (_cur_target_temp + self.tolerance)
            ) * 8.0  # Reduced from 10.0 since we already add 2.0

    # Adjust based on the step size allowed by the local calibration entity
    _new_trv_calibration = round_by_step(_new_trv_calibration, _calibration_step)
    if _new_trv_calibration is None:
        return None

    # limit new setpoint within min/max of the TRV's range
    t_min = _convert_to_float(self.real_trvs[entity_id]["local_calibration_min"])
    t_max = _convert_to_float(self.real_trvs[entity_id]["local_calibration_max"])
    if t_min is None or t_max is None:
        return _new_trv_calibration
    t_min = float(t_min)
    t_max = float(t_max)
    _new_trv_calibration = max(t_min, min(_new_trv_calibration, t_max))

    _new_trv_calibration = _convert_to_float(_new_trv_calibration)
    if _new_trv_calibration is None:
        return None

    _new_trv_calibration = round(_new_trv_calibration, 1)
    _cur_external_temp = round(_cur_external_temp, 1)
    _cur_trv_temp_f = round(_cur_trv_temp_f, 1)
    _current_trv_calibration = round(_current_trv_calibration, 1)

    _logmsg = (
        "better_thermostat %s: %s - new local calibration: %s | external_temp: %s, "
        "trv_temp: %s, calibration: %s"
    )

    _LOGGER.debug(
        _logmsg,
        self.device_name,
        entity_id,
        _new_trv_calibration,
        _cur_external_temp,
        _cur_trv_temp_f,
        _current_trv_calibration,
    )

    return _new_trv_calibration


def calculate_calibration_setpoint(self, entity_id) -> float | None:
    """Calculate new setpoint for the TRV based on its own temperature measurement and the air temperature of the external sensor.

    This calibration is for devices with no local calibration option, it syncs the target temperature of the TRV to a new target
    temperature based on the current temperature of the external sensor.

    Parameters
    ----------
    self :
            self instance of better_thermostat

    Returns
    -------
    float
            new target temp with calibration
    """
    _context = "_calculate_calibration_setpoint()"

    def _convert_to_float(value):
        return convert_to_float(value, self.name, _context)

    if self.cur_temp is None or self.bt_target_temp is None:
        return None

    # Add tolerance check
    _cur_external_temp = float(self.cur_temp)
    _cur_target_temp = float(self.bt_target_temp)

    _calibration_mode = self.real_trvs[entity_id]["advanced"].get(
        "calibration_mode", CalibrationMode.DEFAULT
    )

    _cur_trv_temp_s = self.real_trvs[entity_id]["current_temperature"]
    _cur_trv_temp = _convert_to_float(_cur_trv_temp_s)

    _trv_temp_step_raw = self.real_trvs[entity_id]["target_temp_step"]
    _trv_temp_step = _convert_to_float(_trv_temp_step_raw)
    if _trv_temp_step is None or _trv_temp_step <= 0:
        _trv_temp_step = 0.5

    if _cur_trv_temp is None:
        return None

    _cur_trv_temp = float(_cur_trv_temp)

    _calibrated_setpoint = (_cur_target_temp - _cur_external_temp) + _cur_trv_temp

    _mpc_result = None
    _mpc_use_valve = False
    if _calibration_mode == CalibrationMode.MPC_CALIBRATION:
        _mpc_result, _mpc_use_valve = _compute_mpc_balance(self, entity_id)
        if _mpc_use_valve:
            _calibrated_setpoint = _cur_target_temp
        elif _mpc_result is not None:
            _mpc_setpoint = getattr(_mpc_result, "setpoint_eff_C", None)
            if isinstance(_mpc_setpoint, (int, float)):
                _calibrated_setpoint = float(_mpc_setpoint)
            else:
                _calibrated_setpoint = _cur_target_temp
                _mpc_percent = getattr(_mpc_result, "valve_percent", None)
                if isinstance(_mpc_percent, (int, float)):
                    _max_temp = _convert_to_float(self.real_trvs[entity_id]["max_temp"])
                    if _max_temp is not None:
                        _valve_fraction = max(
                            0.0, min(1.0, float(_mpc_percent) / 100.0)
                        )
                        _calibrated_setpoint = _cur_trv_temp + (
                            (float(_max_temp) - _cur_trv_temp) * _valve_fraction
                        )
        else:
            _calibrated_setpoint = _cur_target_temp
    else:
        self.real_trvs[entity_id].pop("calibration_balance", None)

    _skip_post_adjustments = _calibration_mode == CalibrationMode.MPC_CALIBRATION

    if _calibration_mode == CalibrationMode.AGGRESIVE_CALIBRATION:
        if self.attr_hvac_action == HVACAction.HEATING:
            if _calibrated_setpoint - _cur_trv_temp < 2.5:
                _calibrated_setpoint += 2.5

    if _calibration_mode == CalibrationMode.HEATING_POWER_CALIBRATION:
        if self.attr_hvac_action == HVACAction.HEATING:
            valve_position = heating_power_valve_position(self, entity_id)
            max_temp = _convert_to_float(self.real_trvs[entity_id]["max_temp"])
            if max_temp is not None:
                _calibrated_setpoint = _cur_trv_temp + (
                    (float(max_temp) - _cur_trv_temp) * valve_position
                )

    if _calibrated_setpoint is None:
        return None

    _calibrated_setpoint = float(_calibrated_setpoint)

    if not _skip_post_adjustments:
        if self.attr_hvac_action == HVACAction.IDLE:
            if _calibrated_setpoint - _cur_trv_temp > 0.0:
                _calibrated_setpoint -= self.tolerance * 2.0

    _calibrated_setpoint = fix_target_temperature_calibration(
        self, entity_id, _calibrated_setpoint
    )

    _overheating_protection = self.real_trvs[entity_id]["advanced"].get(
        CONF_PROTECT_OVERHEATING, False
    )

    # Additional adjustment if overheating protection is enabled
    if not _skip_post_adjustments and _overheating_protection is True:
        if self.attr_hvac_action == HVACAction.IDLE:
            _calibrated_setpoint -= (
                _cur_external_temp - (_cur_target_temp + self.tolerance)
            ) * 8.0  # Reduced from 10.0 since we already subtract 2.0

    _calibrated_setpoint = round_by_step(_calibrated_setpoint, _trv_temp_step)
    if _calibrated_setpoint is None:
        return None

    # limit new setpoint within min/max of the TRV's range
    t_min = _convert_to_float(self.real_trvs[entity_id]["min_temp"])
    t_max = _convert_to_float(self.real_trvs[entity_id]["max_temp"])
    if t_min is not None:
        _calibrated_setpoint = max(float(t_min), _calibrated_setpoint)
    if t_max is not None:
        _calibrated_setpoint = min(_calibrated_setpoint, float(t_max))

    _logmsg = (
        "better_thermostat %s: %s - new setpoint calibration: %s | external_temp: %s, "
        "target_temp: %s, trv_temp: %s"
    )

    _LOGGER.debug(
        _logmsg,
        self.device_name,
        entity_id,
        _calibrated_setpoint,
        _cur_external_temp,
        _cur_target_temp,
        _cur_trv_temp,
    )

    return _calibrated_setpoint
